{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "from tensorflow import gfile\n",
    "from tensorflow import logging\n",
    "import pprint\n",
    "import cPickle\n",
    "import numpy as np\n",
    "\n",
    "# Serialized Inception v3 GraphDef that is parsed and imported further down.\n",
    "model_file = \"../../image_caption_data/inception_v3_graph_def.pb\"\n",
    "# Token file read by parse_token_file: each line is an image id, a tab, then a description.\n",
    "input_description_file = \"../../image_caption_data/results_20130124.token\"\n",
    "# Directory holding the image files named in the token file.\n",
    "input_img_dir = \"../../image_caption_data/flickr30k_images/\"\n",
    "# Directory that receives the pickled feature batches.\n",
    "output_folder = \"../../image_caption_data/feature_extraction_inception_v3\"\n",
    "\n",
    "# Number of images whose features are stored together in one output pickle.\n",
    "batch_size = 100\n",
    "\n",
    "# Create the output directory up front if it does not exist yet.\n",
    "if not gfile.Exists(output_folder):\n",
    "    gfile.MakeDirs(output_folder)\n",
    "\n",
    "\n",
    "def parse_token_file(token_file):\n",
    "    \"\"\"Parses a caption token file into an image-name -> descriptions mapping.\n",
    "\n",
    "    Each non-empty line is expected to hold an image id of the form\n",
    "    <img_name>#<suffix>, a tab character, and then the description text.\n",
    "    The suffix after the last '#' is discarded.\n",
    "\n",
    "    Args:\n",
    "        token_file: path of the token file; it is opened through gfile.GFile.\n",
    "\n",
    "    Returns:\n",
    "        A dict mapping every image name to the list of its descriptions,\n",
    "        in the order they appear in the file.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a non-empty line has no tab, or its image id has no '#'.\n",
    "    \"\"\"\n",
    "    img_name_to_tokens = {}\n",
    "    with gfile.GFile(token_file, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        line = line.strip('\\r\\n')\n",
    "        if not line:\n",
    "            # Tolerate blank lines (e.g. an extra newline at the end of the file)\n",
    "            # instead of failing on the tuple unpacking below.\n",
    "            continue\n",
    "        # Split on the first tab only, so a description containing tabs stays whole.\n",
    "        img_id, description = line.split('\\t', 1)\n",
    "        # Split on the last '#' only, so an image name containing '#' stays whole.\n",
    "        img_name, _ = img_id.rsplit('#', 1)\n",
    "        img_name_to_tokens.setdefault(img_name, []).append(description)\n",
    "    return img_name_to_tokens\n",
    "\n",
    "# Parse the captions once, then keep the image names as a list for batching.\n",
    "img_name_to_tokens = parse_token_file(input_description_file)\n",
    "all_img_names = list(img_name_to_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks: total image count, ten image names, and the captions of one image.\n",
    "logging.info(\"num of all images: %d\", len(all_img_names))\n",
    "pprint.pprint(list(img_name_to_tokens)[:10])\n",
    "pprint.pprint(img_name_to_tokens['2778832101.jpg'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pretrained_inception_v3(model_file):\n",
    "    \"\"\"Reads the serialized GraphDef in `model_file` and imports it via tf.import_graph_def.\n",
    "\n",
    "    Args:\n",
    "        model_file: path of the binary GraphDef file; it is read through gfile.\n",
    "    \"\"\"\n",
    "    with gfile.FastGFile(model_file, \"rb\") as graph_file:\n",
    "        serialized_graph = graph_file.read()\n",
    "    inception_graph_def = tf.GraphDef()\n",
    "    inception_graph_def.ParseFromString(serialized_graph)\n",
    "    # name=\"\" imports the nodes without a name prefix, so tensors can be\n",
    "    # looked up by their original names.\n",
    "    tf.import_graph_def(inception_graph_def, name=\"\")\n",
    "\n",
    "\n",
    "load_pretrained_inception_v3(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ceiling division: the last batch may hold fewer than batch_size images.\n",
    "num_batches = (len(all_img_names) + batch_size - 1) // batch_size\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    # Tensor fetched as the feature output for each image.\n",
    "    second_to_last_tensor = sess.graph.get_tensor_by_name(\"pool_3:0\")\n",
    "    for i in range(num_batches):\n",
    "        batch_img_names = all_img_names[i*batch_size: (i+1)*batch_size]\n",
    "        batch_features = []\n",
    "        for img_name in batch_img_names:\n",
    "            img_path = os.path.join(input_img_dir, img_name)\n",
    "            logging.info(\"processing img %s\" % img_name)\n",
    "            if not gfile.Exists(img_path):\n",
    "                raise Exception(\"%s doesn't exists\" % img_path)\n",
    "            # Read inside a context manager so the handle is closed right away;\n",
    "            # this loop runs once per image and must not leave files open.\n",
    "            with gfile.FastGFile(img_path, \"rb\") as img_file:\n",
    "                img_data = img_file.read()\n",
    "            # The raw encoded image bytes are fed to the graph's decode input.\n",
    "            feature_vector = sess.run(second_to_last_tensor,\n",
    "                                      feed_dict={\"DecodeJpeg/contents:0\": img_data})\n",
    "            batch_features.append(feature_vector)\n",
    "        # Stack the per-image results along the first axis.\n",
    "        batch_features = np.vstack(batch_features)\n",
    "        output_filename = os.path.join(output_folder, \"image_features-%d.pickle\" % i)\n",
    "        logging.info(\"writing to file %s\" % output_filename)\n",
    "        # Pickle data is written in binary mode; names and features are stored\n",
    "        # together so each row can be matched back to its image.\n",
    "        with gfile.GFile(output_filename, 'wb') as f:\n",
    "            cPickle.dump((batch_img_names, batch_features), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
